{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true,
          "base_uri": "https://localhost:8080/"
        },
        "id": "D2H6YLCPOXvZ",
        "outputId": "fe45d7a2-e148-49fc-ca4b-d6e84b0049c0"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Requirement already satisfied: transformers==4.48.2 in /usr/local/lib/python3.12/dist-packages (4.48.2)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (3.20.2)\n",
            "Requirement already satisfied: huggingface-hub<1.0,>=0.24.0 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (0.36.0)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (2.0.2)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (25.0)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (6.0.3)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (2025.11.3)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (2.32.4)\n",
            "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (0.21.4)\n",
            "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (0.7.0)\n",
            "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.12/dist-packages (from transformers==4.48.2) (4.67.1)\n",
            "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers==4.48.2) (2024.5.0)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers==4.48.2) (4.15.0)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub<1.0,>=0.24.0->transformers==4.48.2) (1.2.0)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->transformers==4.48.2) (3.4.4)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->transformers==4.48.2) (3.11)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->transformers==4.48.2) (2.5.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->transformers==4.48.2) (2026.1.4)\n",
            "Requirement already satisfied: datasets==2.20.0 in /usr/local/lib/python3.12/dist-packages (2.20.0)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (3.20.2)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (2.0.2)\n",
            "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (18.1.0)\n",
            "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (0.7)\n",
            "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (0.3.8)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (2.2.2)\n",
            "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (2.32.4)\n",
            "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (4.67.1)\n",
            "Requirement already satisfied: xxhash in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (3.6.0)\n",
            "Requirement already satisfied: multiprocess in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (0.70.16)\n",
            "Requirement already satisfied: fsspec<=2024.5.0,>=2023.1.0 in /usr/local/lib/python3.12/dist-packages (from fsspec[http]<=2024.5.0,>=2023.1.0->datasets==2.20.0) (2024.5.0)\n",
            "Requirement already satisfied: aiohttp in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (3.13.3)\n",
            "Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (0.36.0)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (25.0)\n",
            "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.12/dist-packages (from datasets==2.20.0) (6.0.3)\n",
            "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.20.0) (2.6.1)\n",
            "Requirement already satisfied: aiosignal>=1.4.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.20.0) (1.4.0)\n",
            "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.20.0) (25.4.0)\n",
            "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.20.0) (1.8.0)\n",
            "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.20.0) (6.7.0)\n",
            "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.20.0) (0.4.1)\n",
            "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.12/dist-packages (from aiohttp->datasets==2.20.0) (1.22.0)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.21.2->datasets==2.20.0) (4.15.0)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface-hub>=0.21.2->datasets==2.20.0) (1.2.0)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets==2.20.0) (3.4.4)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets==2.20.0) (3.11)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets==2.20.0) (2.5.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.32.2->datasets==2.20.0) (2026.1.4)\n",
            "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets==2.20.0) (2.9.0.post0)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets==2.20.0) (2025.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas->datasets==2.20.0) (2025.3)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas->datasets==2.20.0) (1.17.0)\n",
            "\n",
            "Usage:   \n",
            "  pip3 install [options] <requirement specifier> [package-index-options] ...\n",
            "  pip3 install [options] -r <requirements file> [package-index-options] ...\n",
            "  pip3 install [options] [-e] <vcs project url> ...\n",
            "  pip3 install [options] [-e] <local project path> ...\n",
            "  pip3 install [options] <archive url/path> ...\n",
            "\n",
            "no such option: --uprade\n",
            "Requirement already satisfied: peft==0.17.1 in /usr/local/lib/python3.12/dist-packages (0.17.1)\n",
            "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (2.0.2)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (25.0)\n",
            "Requirement already satisfied: psutil in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (5.9.5)\n",
            "Requirement already satisfied: pyyaml in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (6.0.3)\n",
            "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (2.9.0+cu126)\n",
            "Requirement already satisfied: transformers in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (4.48.2)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (4.67.1)\n",
            "Requirement already satisfied: accelerate>=0.21.0 in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (1.12.0)\n",
            "Requirement already satisfied: safetensors in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (0.7.0)\n",
            "Requirement already satisfied: huggingface_hub>=0.25.0 in /usr/local/lib/python3.12/dist-packages (from peft==0.17.1) (0.36.0)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.12/dist-packages (from huggingface_hub>=0.25.0->peft==0.17.1) (3.20.2)\n",
            "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub>=0.25.0->peft==0.17.1) (2024.5.0)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.12/dist-packages (from huggingface_hub>=0.25.0->peft==0.17.1) (2.32.4)\n",
            "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub>=0.25.0->peft==0.17.1) (4.15.0)\n",
            "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /usr/local/lib/python3.12/dist-packages (from huggingface_hub>=0.25.0->peft==0.17.1) (1.2.0)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (75.2.0)\n",
            "Requirement already satisfied: sympy>=1.13.3 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (1.14.0)\n",
            "Requirement already satisfied: networkx>=2.5.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (3.6.1)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (3.1.6)\n",
            "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (12.6.77)\n",
            "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.6.80 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (12.6.80)\n",
            "Requirement already satisfied: nvidia-cudnn-cu12==9.10.2.21 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (9.10.2.21)\n",
            "Requirement already satisfied: nvidia-cublas-cu12==12.6.4.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (12.6.4.1)\n",
            "Requirement already satisfied: nvidia-cufft-cu12==11.3.0.4 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (11.3.0.4)\n",
            "Requirement already satisfied: nvidia-curand-cu12==10.3.7.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (10.3.7.77)\n",
            "Requirement already satisfied: nvidia-cusolver-cu12==11.7.1.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (11.7.1.2)\n",
            "Requirement already satisfied: nvidia-cusparse-cu12==12.5.4.2 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (12.5.4.2)\n",
            "Requirement already satisfied: nvidia-cusparselt-cu12==0.7.1 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (0.7.1)\n",
            "Requirement already satisfied: nvidia-nccl-cu12==2.27.5 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (2.27.5)\n",
            "Requirement already satisfied: nvidia-nvshmem-cu12==3.3.20 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (3.3.20)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.6.77 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (12.6.77)\n",
            "Requirement already satisfied: nvidia-nvjitlink-cu12==12.6.85 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (12.6.85)\n",
            "Requirement already satisfied: nvidia-cufile-cu12==1.11.1.6 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (1.11.1.6)\n",
            "Requirement already satisfied: triton==3.5.0 in /usr/local/lib/python3.12/dist-packages (from torch>=1.13.0->peft==0.17.1) (3.5.0)\n",
            "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.12/dist-packages (from transformers->peft==0.17.1) (2025.11.3)\n",
            "Requirement already satisfied: tokenizers<0.22,>=0.21 in /usr/local/lib/python3.12/dist-packages (from transformers->peft==0.17.1) (0.21.4)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.12/dist-packages (from sympy>=1.13.3->torch>=1.13.0->peft==0.17.1) (1.3.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->torch>=1.13.0->peft==0.17.1) (3.0.3)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub>=0.25.0->peft==0.17.1) (3.4.4)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub>=0.25.0->peft==0.17.1) (3.11)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub>=0.25.0->peft==0.17.1) (2.5.0)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests->huggingface_hub>=0.25.0->peft==0.17.1) (2026.1.4)\n",
            "Downloading...\n",
            "From: https://drive.google.com/uc?id=1Bo2RfzdgIhMl2wtT5lB0N4m-Wck73HRY\n",
            "To: /content/gpt-translations.csv\n",
            "100% 6.84M/6.84M [00:00<00:00, 204MB/s]\n"
          ]
        }
      ],
      "source": [
        "# Pin the library versions used by this notebook.\n",
        "!pip install transformers==4.48.2\n",
        "!pip install datasets==2.20.0\n",
        "# The flag is spelled --upgrade; a misspelled flag makes pip abort with 'no such option' and install nothing.\n",
        "!pip install --upgrade --no-cache-dir gdown==4.5.4\n",
        "!pip install peft==0.17.1\n",
        "\n",
        "# Download the dataset from Google Drive by file id (saved in the working directory as gpt-translations.csv).\n",
        "!gdown 1Bo2RfzdgIhMl2wtT5lB0N4m-Wck73HRY"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "U3ahFqA9SgjS"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "\n",
        "import pandas as pd\n",
        "\n",
        "def is_positive(x):\n",
        "    \"\"\"Return 1 if the values of the JSON object encoded in ``x`` sum to more than zero, else 0.\"\"\"\n",
        "    total = sum(json.loads(x).values())\n",
        "    return int(total > 0)\n",
        "\n",
        "# gdown stores the download as gpt-translations.csv (with a hyphen), so read exactly that name.\n",
        "data = pd.read_csv(\"gpt-translations.csv\")\n",
        "data[\"text\"] = data[\"gpt_translation\"] # overwrite the German text\n",
        "\n",
        "# Collapse each JSON-encoded annotation column into a 0/1 indicator column.\n",
        "data[\"is_left_wing\"] = data[\"left_wing\"].apply(is_positive)\n",
        "data[\"is_right_wing\"] = data[\"right_wing\"].apply(is_positive)\n",
        "data[\"is_anti_elitism\"] = data[\"anti_elitism\"].apply(is_positive)\n",
        "data[\"is_people_centrism\"] = data[\"people_centrism\"].apply(is_positive)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true,
          "referenced_widgets": [
            "7d976182e9c3445c8831d813903d8db1",
            "f56f690a02df4644bc9be2b681632bb7",
            "ef1b0daeaa87480ca81538fd806bdef5",
            "4a1d7344d02d4ac7a2f004dd2148af41",
            "76d85292981746fdbf5c6285d82a8260"
          ]
        },
        "id": "9JgpFVFJhh-k",
        "outputId": "32699d9f-57fc-47d0-e366-3a2d94e4c533"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_auth.py:104: UserWarning: \n",
            "Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.\n",
            "You are not authenticated with the Hugging Face Hub in this notebook.\n",
            "If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "7d976182e9c3445c8831d813903d8db1",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f56f690a02df4644bc9be2b681632bb7",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ef1b0daeaa87480ca81538fd806bdef5",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "4a1d7344d02d4ac7a2f004dd2148af41",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "76d85292981746fdbf5c6285d82a8260",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from collections import defaultdict\n",
        "import csv\n",
        "import random\n",
        "import time\n",
        "\n",
        "from datasets import load_metric\n",
        "import numpy as np\n",
        "from sklearn.model_selection import train_test_split\n",
        "import torch\n",
        "from transformers import RobertaTokenizerFast\n",
        "\n",
        "# Presumably the maximum tokenized sequence length - TODO confirm where it is consumed.\n",
        "mlength = 512\n",
        "# Presumably the number of target labels - TODO confirm against the model setup.\n",
        "nclasses = 4\n",
        "# Wall-clock timestamp taken when this cell runs.\n",
        "start = time.time()\n",
        "# Fast RoBERTa tokenizer loaded from the 'roberta-large' checkpoint.\n",
        "tokenizer = RobertaTokenizerFast.from_pretrained('roberta-large')\n",
        "\n",
        "class PSCDataset(torch.utils.data.Dataset):\n",
        "    \"\"\"Map-style dataset that pairs tokenizer encodings with per-sample label vectors.\"\"\"\n",
        "\n",
        "    def __init__(self, encodings, labels):\n",
        "        # encodings: mapping of field name -> indexable per-sample values\n",
        "        # labels: indexable per-sample targets\n",
        "        self.encodings = encodings\n",
        "        self.labels = labels\n",
        "\n",
        "    def __len__(self):\n",
        "        return len(self.labels)\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        sample = {}\n",
        "        for name, column in self.encodings.items():\n",
        "            sample[name] = torch.tensor(column[idx])\n",
        "        # Targets are always returned as float tensors.\n",
        "        sample['labels'] = torch.tensor(self.labels[idx], dtype=torch.float)\n",
        "        return sample\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "eGj5f_Tzfw0N"
      },
      "outputs": [],
      "source": [
        "def metric_helper(labels, logits):\n",
        "    \"\"\"Score multi-label predictions at the F1-maximizing threshold of each label.\n",
        "\n",
        "    For every label column the precision-recall curve is computed on the\n",
        "    given data and the point with the highest F1 is selected. Because the\n",
        "    thresholds are tuned on the same data they are scored on, the results\n",
        "    are an optimistic estimate for that data.\n",
        "\n",
        "    Args:\n",
        "        labels: array-like of shape (n_samples, n_labels) holding 0/1 targets.\n",
        "        logits: array-like of the same shape holding one score per label.\n",
        "\n",
        "    Returns:\n",
        "        A flat tuple with (precision, recall, f1) for each label column in\n",
        "        order, then micro (precision, recall, f1), then macro (precision,\n",
        "        recall, f1). For four labels that is 4 * 3 + 3 + 3 = 18 values.\n",
        "        Undefined ratios (0/0) are reported as 0.0 instead of NaN.\n",
        "    \"\"\"\n",
        "    # Local import keeps the helper runnable regardless of cell execution order.\n",
        "    from sklearn.metrics import precision_recall_curve\n",
        "\n",
        "    labels = np.asarray(labels)\n",
        "    logits = np.asarray(logits)\n",
        "    n_labels = labels.shape[1]\n",
        "\n",
        "    best_precision, best_recall, best_f1, best_threshold = [], [], [], []\n",
        "    for col in range(n_labels):\n",
        "        precision, recall, thresholds = precision_recall_curve(labels[:, col], logits[:, col])\n",
        "        denom = precision + recall\n",
        "        # precision == recall == 0 makes F1 0/0; score those points as 0 so\n",
        "        # they can neither win the argmax nor poison it with NaN.\n",
        "        f1 = np.divide(\n",
        "            2 * precision * recall,\n",
        "            denom,\n",
        "            out=np.zeros_like(denom, dtype=float),\n",
        "            where=denom > 0,\n",
        "        )\n",
        "        best = int(np.argmax(f1))  # first occurrence of the maximum\n",
        "        best_precision.append(precision[best])\n",
        "        best_recall.append(recall[best])\n",
        "        best_f1.append(f1[best])\n",
        "        # The curve has one more point than thresholds (the final P=1, R=0\n",
        "        # anchor), so clamp the index before looking up the threshold.\n",
        "        best_threshold.append(thresholds[min(best, len(thresholds) - 1)])\n",
        "\n",
        "    macro_precision = sum(best_precision) / n_labels\n",
        "    macro_recall = sum(best_recall) / n_labels\n",
        "    macro_f1 = sum(best_f1) / n_labels\n",
        "\n",
        "    # Binarize every column with its own F1-optimal threshold.\n",
        "    predicted = logits >= np.asarray(best_threshold)\n",
        "    positive = labels == 1\n",
        "    negative = labels == 0\n",
        "\n",
        "    # Pool the confusion counts over all labels for the micro averages.\n",
        "    tp = np.sum(predicted & positive)\n",
        "    fp = np.sum(predicted & negative)\n",
        "    fn = np.sum(~predicted & positive)\n",
        "\n",
        "    micro_precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0\n",
        "    micro_recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0\n",
        "    micro_denom = micro_precision + micro_recall\n",
        "    micro_f1 = 2 * micro_precision * micro_recall / micro_denom if micro_denom > 0 else 0.0\n",
        "\n",
        "    per_label = []\n",
        "    for p, r, f in zip(best_precision, best_recall, best_f1):\n",
        "        per_label.extend((p, r, f))\n",
        "\n",
        "    return (*per_label, micro_precision, micro_recall, micro_f1, macro_precision, macro_recall, macro_f1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true,
          "referenced_widgets": [
            "8194f20f4fcf4f30a641a5acc289e50a"
          ]
        },
        "id": "lxfft6tepJZj",
        "outputId": "fd59ee35-fac3-4f9e-c2d4-dd6d619aefc5"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "8194f20f4fcf4f30a641a5acc289e50a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 07:59, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.279600</td>\n",
              "      <td>0.279439</td>\n",
              "      <td>0.707865</td>\n",
              "      <td>0.597403</td>\n",
              "      <td>0.816754</td>\n",
              "      <td>0.612440</td>\n",
              "      <td>0.683615</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.195400</td>\n",
              "      <td>0.286391</td>\n",
              "      <td>0.730964</td>\n",
              "      <td>0.666667</td>\n",
              "      <td>0.819149</td>\n",
              "      <td>0.674419</td>\n",
              "      <td>0.722800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.082500</td>\n",
              "      <td>0.298779</td>\n",
              "      <td>0.750000</td>\n",
              "      <td>0.750000</td>\n",
              "      <td>0.833747</td>\n",
              "      <td>0.691892</td>\n",
              "      <td>0.756410</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "======================\n",
            "0.7564096975387298 2e-05 11\n",
            "Metric:  Precision, Recall, F1\n",
            "Left wing: 0.71, 0.72, 0.71\n",
            "Right wing: 0.72, 0.71, 0.72\n",
            "Anti elitism: 0.80, 0.88, 0.84\n",
            "People centrism: 0.64, 0.81, 0.71\n",
            "Micro: 0.73, 0.82, 0.77\n",
            "Macro: 0.72, 0.78, 0.74\n",
            "======================\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
            "  warnings.warn(\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 08:02, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.279200</td>\n",
              "      <td>0.272584</td>\n",
              "      <td>0.709677</td>\n",
              "      <td>0.680000</td>\n",
              "      <td>0.796339</td>\n",
              "      <td>0.634409</td>\n",
              "      <td>0.705106</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.191000</td>\n",
              "      <td>0.266766</td>\n",
              "      <td>0.748466</td>\n",
              "      <td>0.658537</td>\n",
              "      <td>0.818942</td>\n",
              "      <td>0.666667</td>\n",
              "      <td>0.723153</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.079600</td>\n",
              "      <td>0.316034</td>\n",
              "      <td>0.736264</td>\n",
              "      <td>0.717391</td>\n",
              "      <td>0.826316</td>\n",
              "      <td>0.670051</td>\n",
              "      <td>0.737505</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
            "  warnings.warn(\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 08:03, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.257800</td>\n",
              "      <td>0.269593</td>\n",
              "      <td>0.695187</td>\n",
              "      <td>0.625000</td>\n",
              "      <td>0.825737</td>\n",
              "      <td>0.640394</td>\n",
              "      <td>0.696580</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.176700</td>\n",
              "      <td>0.276311</td>\n",
              "      <td>0.735135</td>\n",
              "      <td>0.631579</td>\n",
              "      <td>0.822281</td>\n",
              "      <td>0.658960</td>\n",
              "      <td>0.711989</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.060400</td>\n",
              "      <td>0.306788</td>\n",
              "      <td>0.728205</td>\n",
              "      <td>0.707071</td>\n",
              "      <td>0.836957</td>\n",
              "      <td>0.652406</td>\n",
              "      <td>0.731160</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3140454732.py:15: RuntimeWarning: invalid value encountered in divide\n",
            "  l3_f1 = 2 * l3_precision * l3_recall / (l3_precision + l3_recall)\n"
          ]
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
            "  warnings.warn(\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 07:58, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.255300</td>\n",
              "      <td>0.290941</td>\n",
              "      <td>0.610778</td>\n",
              "      <td>0.707317</td>\n",
              "      <td>0.809399</td>\n",
              "      <td>0.769953</td>\n",
              "      <td>0.724362</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.297800</td>\n",
              "      <td>0.268816</td>\n",
              "      <td>0.721519</td>\n",
              "      <td>0.735632</td>\n",
              "      <td>0.795970</td>\n",
              "      <td>0.777778</td>\n",
              "      <td>0.757725</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.116100</td>\n",
              "      <td>0.294590</td>\n",
              "      <td>0.690058</td>\n",
              "      <td>0.808511</td>\n",
              "      <td>0.806202</td>\n",
              "      <td>0.796380</td>\n",
              "      <td>0.775288</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3140454732.py:15: RuntimeWarning: invalid value encountered in divide\n",
            "  l3_f1 = 2 * l3_precision * l3_recall / (l3_precision + l3_recall)\n"
          ]
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3140454732.py:15: RuntimeWarning: invalid value encountered in divide\n",
            "  l3_f1 = 2 * l3_precision * l3_recall / (l3_precision + l3_recall)\n",
            "/tmp/ipython-input-2258470705.py:22: RuntimeWarning: invalid value encountered in divide\n",
            "  l3_f1 = 2 * l3_precision * l3_recall / (l3_precision + l3_recall)\n"
          ]
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "======================\n",
            "0.7752876896788425 2e-05 12\n",
            "Metric:  Precision, Recall, F1\n",
            "Left wing: 0.69, 0.75, 0.72\n",
            "Right wing: 0.67, 0.82, 0.73\n",
            "Anti elitism: 0.83, 0.87, 0.85\n",
            "People centrism: 0.67, 0.76, 0.71\n",
            "Micro: 0.74, 0.81, 0.78\n",
            "Macro: 0.71, 0.80, 0.75\n",
            "======================\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
            "  warnings.warn(\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 08:02, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.250800</td>\n",
              "      <td>0.288412</td>\n",
              "      <td>0.603550</td>\n",
              "      <td>0.721649</td>\n",
              "      <td>0.792839</td>\n",
              "      <td>0.728111</td>\n",
              "      <td>0.711537</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.275800</td>\n",
              "      <td>0.285373</td>\n",
              "      <td>0.683871</td>\n",
              "      <td>0.747475</td>\n",
              "      <td>0.795812</td>\n",
              "      <td>0.764706</td>\n",
              "      <td>0.747966</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.145700</td>\n",
              "      <td>0.281557</td>\n",
              "      <td>0.697987</td>\n",
              "      <td>0.778947</td>\n",
              "      <td>0.818182</td>\n",
              "      <td>0.807175</td>\n",
              "      <td>0.775573</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/tmp/ipython-input-3140454732.py:18: RuntimeWarning: invalid value encountered in divide\n",
            "  l4_f1 = 2 * l4_precision * l4_recall / (l4_precision + l4_recall)\n"
          ]
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "======================\n",
            "0.775572662919114 1.75e-05 12\n",
            "Metric:  Precision, Recall, F1\n",
            "Left wing: 0.68, 0.77, 0.72\n",
            "Right wing: 0.70, 0.79, 0.74\n",
            "Anti elitism: 0.79, 0.90, 0.84\n",
            "People centrism: 0.69, 0.73, 0.71\n",
            "Micro: 0.74, 0.82, 0.78\n",
            "Macro: 0.72, 0.80, 0.75\n",
            "======================\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
            "  warnings.warn(\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 08:01, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.296300</td>\n",
              "      <td>0.291777</td>\n",
              "      <td>0.606061</td>\n",
              "      <td>0.686869</td>\n",
              "      <td>0.810127</td>\n",
              "      <td>0.752294</td>\n",
              "      <td>0.713837</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.267600</td>\n",
              "      <td>0.274961</td>\n",
              "      <td>0.707317</td>\n",
              "      <td>0.762887</td>\n",
              "      <td>0.815190</td>\n",
              "      <td>0.772093</td>\n",
              "      <td>0.764372</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.147500</td>\n",
              "      <td>0.277243</td>\n",
              "      <td>0.705882</td>\n",
              "      <td>0.791667</td>\n",
              "      <td>0.825641</td>\n",
              "      <td>0.796020</td>\n",
              "      <td>0.779802</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "======================\n",
            "0.7798024864365953 1.5e-05 12\n",
            "Metric:  Precision, Recall, F1\n",
            "Left wing: 0.67, 0.77, 0.72\n",
            "Right wing: 0.72, 0.80, 0.76\n",
            "Anti elitism: 0.81, 0.88, 0.84\n",
            "People centrism: 0.66, 0.79, 0.72\n",
            "Micro: 0.74, 0.83, 0.78\n",
            "Macro: 0.71, 0.81, 0.76\n",
            "======================\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
            "  warnings.warn(\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 08:02, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.241100</td>\n",
              "      <td>0.269356</td>\n",
              "      <td>0.690909</td>\n",
              "      <td>0.617021</td>\n",
              "      <td>0.846348</td>\n",
              "      <td>0.641148</td>\n",
              "      <td>0.698857</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.184900</td>\n",
              "      <td>0.253729</td>\n",
              "      <td>0.706522</td>\n",
              "      <td>0.709677</td>\n",
              "      <td>0.839050</td>\n",
              "      <td>0.656716</td>\n",
              "      <td>0.727991</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.146300</td>\n",
              "      <td>0.277375</td>\n",
              "      <td>0.703911</td>\n",
              "      <td>0.744681</td>\n",
              "      <td>0.852041</td>\n",
              "      <td>0.654028</td>\n",
              "      <td>0.738665</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "======================\n",
            "0.7386651794836143 2e-05 13\n",
            "Metric:  Precision, Recall, F1\n",
            "Left wing: 0.69, 0.70, 0.70\n",
            "Right wing: 0.69, 0.69, 0.69\n",
            "Anti elitism: 0.78, 0.88, 0.83\n",
            "People centrism: 0.66, 0.81, 0.73\n",
            "Micro: 0.73, 0.81, 0.76\n",
            "Macro: 0.71, 0.77, 0.74\n",
            "======================\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
            "  warnings.warn(\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 08:00, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.258500</td>\n",
              "      <td>0.261916</td>\n",
              "      <td>0.690647</td>\n",
              "      <td>0.639175</td>\n",
              "      <td>0.838542</td>\n",
              "      <td>0.680851</td>\n",
              "      <td>0.712304</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.146500</td>\n",
              "      <td>0.251186</td>\n",
              "      <td>0.700637</td>\n",
              "      <td>0.666667</td>\n",
              "      <td>0.853598</td>\n",
              "      <td>0.670000</td>\n",
              "      <td>0.722725</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.131400</td>\n",
              "      <td>0.271667</td>\n",
              "      <td>0.704663</td>\n",
              "      <td>0.765432</td>\n",
              "      <td>0.862245</td>\n",
              "      <td>0.662791</td>\n",
              "      <td>0.748783</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "======================\n",
            "0.7487827267085668 1.75e-05 13\n",
            "Metric:  Precision, Recall, F1\n",
            "Left wing: 0.77, 0.67, 0.72\n",
            "Right wing: 0.74, 0.70, 0.72\n",
            "Anti elitism: 0.81, 0.87, 0.84\n",
            "People centrism: 0.68, 0.77, 0.72\n",
            "Micro: 0.76, 0.79, 0.78\n",
            "Macro: 0.75, 0.75, 0.75\n",
            "======================\n"
          ]
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
            "  warnings.warn(\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
            "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
            "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='2451' max='2451' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [2451/2451 08:05, Epoch 3/3]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Epoch</th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Validation Loss</th>\n",
              "      <th>Left Wing F1</th>\n",
              "      <th>Right Wing F1</th>\n",
              "      <th>Anti Elitism F1</th>\n",
              "      <th>People Centrism F1</th>\n",
              "      <th>Macro F1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>0.285400</td>\n",
              "      <td>0.252576</td>\n",
              "      <td>0.712500</td>\n",
              "      <td>0.647619</td>\n",
              "      <td>0.848485</td>\n",
              "      <td>0.636364</td>\n",
              "      <td>0.711242</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>0.152700</td>\n",
              "      <td>0.245210</td>\n",
              "      <td>0.727273</td>\n",
              "      <td>0.716418</td>\n",
              "      <td>0.859296</td>\n",
              "      <td>0.636872</td>\n",
              "      <td>0.734965</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>0.131700</td>\n",
              "      <td>0.267744</td>\n",
              "      <td>0.714286</td>\n",
              "      <td>0.750000</td>\n",
              "      <td>0.872340</td>\n",
              "      <td>0.617801</td>\n",
              "      <td>0.738607</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "import numpy as np\n",
        "from sklearn.metrics import f1_score, accuracy_score, precision_recall_fscore_support, precision_score, recall_score, precision_recall_curve\n",
        "\n",
        "# Define custom compute_metrics function\n",
        "def compute_metrics(pred):\n",
        "    logits, labels = pred\n",
        "\n",
        "    l1_precision, l1_recall, l1_thresholds = precision_recall_curve(labels[:, 0], logits[:, 0])\n",
        "    l1_f1 = 2 * l1_precision * l1_recall / (l1_precision + l1_recall)\n",
        "\n",
        "    l2_precision, l2_recall, l2_thresholds = precision_recall_curve(labels[:, 1], logits[:, 1])\n",
        "    l2_f1 = 2 * l2_precision * l2_recall / (l2_precision + l2_recall)\n",
        "\n",
        "    l3_precision, l3_recall, l3_thresholds = precision_recall_curve(labels[:, 2], logits[:, 2])\n",
        "    l3_f1 = 2 * l3_precision * l3_recall / (l3_precision + l3_recall)\n",
        "\n",
        "    l4_precision, l4_recall, l4_thresholds = precision_recall_curve(labels[:, 3], logits[:, 3])\n",
        "    l4_f1 = 2 * l4_precision * l4_recall / (l4_precision + l4_recall)\n",
        "\n",
        "    return {\n",
        "        \"left_wing_f1\": max(l1_f1),\n",
        "        \"right_wing_f1\": max(l2_f1),\n",
        "        \"anti_elitism_f1\": max(l3_f1),\n",
        "        \"people_centrism_f1\": max(l4_f1),\n",
        "        \"macro_f1\": (max(l1_f1)+max(l2_f1)+max(l3_f1)+max(l4_f1))/4,\n",
        "    }\n",
        "\n",
        "seeds = range(11,14)\n",
        "epochs = 3\n",
        "records = defaultdict(list)\n",
        "for seed in seeds:\n",
        "  np.random.seed(seed)\n",
        "  torch.manual_seed(seed)\n",
        "  random.seed(seed)\n",
        "\n",
        "  # 1759 (20%) for testing\n",
        "  shuffled_data = data.sample(random_state = seed, frac = 1)\n",
        "  test_data = shuffled_data[:1759]\n",
        "  dev_data = shuffled_data[1759:(1759 + 500)]\n",
        "  train_data = shuffled_data[(1759 + 500):]\n",
        "\n",
        "  test_labels = list(zip(test_data[\"is_left_wing\"].tolist(), test_data[\"is_right_wing\"].tolist(), test_data[\"is_anti_elitism\"].tolist(), test_data[\"is_people_centrism\"].tolist()))\n",
        "  test_texts = test_data[\"text\"].tolist()\n",
        "\n",
        "  train_labels = list(zip(train_data[\"is_left_wing\"].tolist(), train_data[\"is_right_wing\"].tolist(), train_data[\"is_anti_elitism\"].tolist(), train_data[\"is_people_centrism\"].tolist()))\n",
        "  train_texts = train_data[\"text\"].tolist()\n",
        "\n",
        "  dev_labels = list(zip(dev_data[\"is_left_wing\"].tolist(), dev_data[\"is_right_wing\"].tolist(), dev_data[\"is_anti_elitism\"].tolist(), dev_data[\"is_people_centrism\"].tolist()))\n",
        "  dev_texts = dev_data[\"text\"].tolist()\n",
        "\n",
        "  zipped = list(zip(train_texts, train_labels))\n",
        "  random.shuffle(zipped)\n",
        "  train_texts, train_labels = zip(*zipped)\n",
        "\n",
        "  train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=mlength)\n",
        "  dev_encodings = tokenizer(dev_texts, truncation=True, padding=True, max_length = mlength)\n",
        "  test_encodings = tokenizer(test_texts, truncation=True, padding=True, max_length= mlength)\n",
        "\n",
        "  train_dataset = PSCDataset(train_encodings, train_labels)\n",
        "  dev_dataset = PSCDataset(dev_encodings, dev_labels)\n",
        "  test_dataset = PSCDataset(test_encodings, test_labels)\n",
        "\n",
        "  from transformers import RobertaForSequenceClassification, TrainingArguments, Trainer\n",
        "  max_dev_f1 = 0\n",
        "  test_f1 = 0\n",
        "  for learning_rate in [20e-6, 175e-7, 15e-6]:\n",
        "\n",
        "    training_args = TrainingArguments(\n",
        "        output_dir=\"./results\",          # output directory\n",
        "        num_train_epochs=epochs,         # total number of training epochs\n",
        "        per_device_train_batch_size=8,   # batch size per device during training\n",
        "        per_device_eval_batch_size=64,   # batch size for evaluation\n",
        "        warmup_steps=0,                  # number of warmup steps for learning rate scheduler\n",
        "        weight_decay=0.01,               # strength of weight decay\n",
        "        logging_dir='./logs',            # directory for storing logs\n",
        "        logging_steps=10,\n",
        "        learning_rate = learning_rate,\n",
        "        save_strategy= \"epoch\",\n",
        "        evaluation_strategy=\"epoch\",\n",
        "        load_best_model_at_end= True,\n",
        "        metric_for_best_model=\"macro_f1\",\n",
        "        save_total_limit = 2,\n",
        "        seed = seed,\n",
        "        report_to = \"none\",\n",
        "    )\n",
        "\n",
        "    def model_init():\n",
        "        return RobertaForSequenceClassification.from_pretrained(\"roberta-large\", num_labels=nclasses, problem_type=\"multi_label_classification\")\n",
        "\n",
        "    trainer = Trainer(\n",
        "        model_init=model_init,               # the instantiated 🤗 Transformers model to be trained\n",
        "        args=training_args,                  # training arguments, defined above\n",
        "        train_dataset=train_dataset,         # training dataset\n",
        "        eval_dataset=dev_dataset,            # evaluation dataset\n",
        "        compute_metrics=compute_metrics,     # compute_metrics\n",
        "        )\n",
        "\n",
        "    trainer.train()\n",
        "    # dev accuracy\n",
        "    predictions = trainer.predict(dev_dataset)\n",
        "    logits = predictions.predictions\n",
        "    labels = dev_dataset.labels\n",
        "    dev_macro_f1 = metric_helper(labels, logits)[-1]\n",
        "\n",
        "    if dev_macro_f1 > max_dev_f1: # update test accuracy once dev accuracy improves\n",
        "      max_dev_f1 = dev_macro_f1\n",
        "      predictions = trainer.predict(test_dataset)\n",
        "      logits = predictions.predictions\n",
        "      labels = test_dataset.labels\n",
        "      max_l1_precision, max_l1_recall, max_l1_f1, max_l2_precision, max_l2_recall, max_l2_f1, max_l3_precision, max_l3_recall, max_l3_f1, max_l4_precision, max_l4_recall, max_l4_f1, micro_precision, micro_recall, micro_f1, macro_precision, macro_recall, macro_f1 = metric_helper(labels, logits)\n",
        "      print(\"======================\")\n",
        "      print(max_dev_f1, learning_rate, seed)\n",
        "      print(\"Metric:  Precision, Recall, F1\")\n",
        "      print(f\"Left wing: {max_l1_precision:.2f}, {max_l1_recall:.2f}, {max_l1_f1:.2f}\")\n",
        "      print(f\"Right wing: {max_l2_precision:.2f}, {max_l2_recall:.2f}, {max_l2_f1:.2f}\")\n",
        "      print(f\"Anti elitism: {max_l3_precision:.2f}, {max_l3_recall:.2f}, {max_l3_f1:.2f}\")\n",
        "      print(f\"People centrism: {max_l4_precision:.2f}, {max_l4_recall:.2f}, {max_l4_f1:.2f}\")\n",
        "      print(f\"Micro: {micro_precision:.2f}, {micro_recall:.2f}, {micro_f1:.2f}\")\n",
        "      print(f\"Macro: {macro_precision:.2f}, {macro_recall:.2f}, {macro_f1:.2f}\")\n",
        "      print(\"======================\")\n",
        "      records[seed] = [max_l1_precision, max_l1_recall, max_l1_f1, max_l2_precision, max_l2_recall, max_l2_f1, max_l3_precision, max_l3_recall, max_l3_f1, max_l4_precision, max_l4_recall, max_l4_f1, micro_precision, micro_recall, micro_f1, macro_precision, macro_recall, macro_f1]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "WyMgg6PM03ck",
        "outputId": "55603d68-4c6f-4215-8457-33ccf998f1bb"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "11 [np.float64(0.7061068702290076), np.float64(0.72265625), np.float64(0.7142857142857143), np.float64(0.722972972972973), np.float64(0.7086092715231788), np.float64(0.7157190635451505), np.float64(0.797768479776848), np.float64(0.8773006134969326), np.float64(0.835646457268079), np.float64(0.6352357320099256), np.float64(0.8126984126984127), np.float64(0.713091922005571), np.float64(0.7320261437908496), np.float64(0.8151382823871907), np.float64(0.7713498622589531), np.float64(0.7155210137471886), np.float64(0.780316136929631), np.float64(0.7446857892761287)]\n",
            "12 [np.float64(0.6666666666666666), np.float64(0.7739463601532567), np.float64(0.7163120567375887), np.float64(0.7164948453608248), np.float64(0.8034682080924855), np.float64(0.757493188010899), np.float64(0.8113207547169812), np.float64(0.8789308176100629), np.float64(0.8437735849056605), np.float64(0.6622340425531915), np.float64(0.7904761904761904), np.float64(0.7206946454413893), np.float64(0.735595390524968), np.float64(0.8296028880866426), np.float64(0.7797760434340008), np.float64(0.714179077324416), np.float64(0.8117053940829988), np.float64(0.7595683687738843)]\n",
            "13 [np.float64(0.7698744769874477), np.float64(0.6690909090909091), np.float64(0.715953307392996), np.float64(0.738562091503268), np.float64(0.6975308641975309), np.float64(0.7174603174603175), np.float64(0.8118668596237337), np.float64(0.8711180124223602), np.float64(0.8404494382022472), np.float64(0.6814621409921671), np.float64(0.7676470588235295), np.float64(0.7219917012448134), np.float64(0.7633015006821282), np.float64(0.7874736101337086), np.float64(0.7751991686872186), np.float64(0.7504413922766542), np.float64(0.7513467111335823), np.float64(0.7489636910750934)]\n",
            "[0.714216   0.72189784 0.71551703 0.72600997 0.73653611 0.73022419\n",
            " 0.80698536 0.87578315 0.83995649 0.65964397 0.79027389 0.71859276\n",
            " 0.74364101 0.81073826 0.77544169 0.72671383 0.78112275 0.75107262]\n"
          ]
        }
      ],
      "source": [
        "for seed, metrics in records.items():\n",
        "  print(seed, metrics)\n",
        "print(np.mean(list(records.values()), axis = 0))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "machine_shape": "hm",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}